/**
 *@file carPlatePostProcess.cpp
 *@author lynxi
 *@version v1.0
 *@date 2023-03-07
 *@par Copyright:
 *© 2022 北京灵汐科技有限公司 版权所有。
 * 注意：以下内容均为北京灵汐科技有限公司原创，未经本公司允许，不得转载，否则将视为侵权；对于不遵守此声明或者其他违法使用以下内容者，本公司依法保留追究权。\n
 *© 2022 Lynxi Technologies Co., Ltd. All rights reserved.
 * NOTICE: All information contained here is, and remains the property of Lynxi.
 *This file can not be copied or distributed without the permission of Lynxi
 *Technologies Co., Ltd.
 *@brief 车牌检测后处理
 */

#include "carPlatePostProcess.h"
#include <algorithm>
#include <cmath>
#include <codecvt>
#include <cstdint>
#include <cstring>
#include <locale>
#include <sys/time.h>
#ifdef LYNXI_PLUGIN
#include "lyn_plugin_dev.h"
#endif

namespace COMMON {

// IEEE-754 binary16 layout parameters. Kept as namespace-scope variables for
// source compatibility; the conversion functions below use local constants so
// they do not depend on this mutable global state.
int16_t expBitNum = 5;
int16_t baseBitNum = 15 - expBitNum;
int maxExp = 32;    // pow(2, expBitNum);
int maxBase = 1024; // pow(2, baseBitNum);
int biasExp = 15;   // maxExp / 2 - 1;
int sig[2] = {1, -1};
// Weight of one mantissa step of a subnormal half: exactly 2^-24.
float result = 5.9604644775390625e-08f;

/**
 * @brief Convert an IEEE-754 binary16 bit pattern to float.
 * @param ib Raw half bits (bit 15 sign, bits 14-10 exponent, bits 9-0
 *           mantissa) carried in an int16_t.
 * @return The exactly equivalent float. Signed zeros, subnormals, Inf and
 *         NaN are all preserved.
 */
float half2float(int16_t ib) {
  const uint32_t bits = static_cast<uint16_t>(ib);
  const uint32_t sign = (bits >> 15) & 0x1u;
  const uint32_t exponent = (bits >> 10) & 0x1fu;
  const uint32_t mantissa = bits & 0x3ffu;

  if (exponent == 0) {
    // Zero or subnormal: value = mantissa * 2^-24, which is exact in float.
    // The sign is applied explicitly so 0x8000 decodes to -0.0f.
    const float magnitude = std::ldexp(static_cast<float>(mantissa), -24);
    return sign ? -magnitude : magnitude;
  }

  // Normal numbers are re-biased (15 -> 127). The all-ones half exponent
  // (Inf/NaN) maps to the all-ones float exponent. The 10-bit mantissa is
  // left-aligned into the 23-bit float mantissa.
  const uint32_t f_exp =
      (exponent == 0x1fu) ? 0xffu : (exponent + (127u - 15u));
  const uint32_t f_bits = (sign << 31) | (f_exp << 23) | (mantissa << 13);
  float out;
  std::memcpy(&out, &f_bits, sizeof(out)); // well-defined bit cast
  return out;
}

/**
 * @brief Convert a float to an IEEE-754 binary16 bit pattern.
 *
 * - The mantissa is rounded with round() (half away from zero).
 * - Magnitudes above 65504, including +/-Inf, saturate to +/-65504.
 * - Magnitudes below 2^-25 become +0.
 * - NaN maps to a half NaN.
 *
 * @param value Input float.
 * @return Half bits carried in an int16_t.
 */
int16_t float2half(float value) {
  const int maxExp = 32;    // 2^5 exponent codes
  const int maxBase = 1024; // 2^10 mantissa steps
  const int biasExp = 15;

  int16_t s = value < 0;
  int e = 0;
  int16_t m = 0;
  bool inf_flag = false;

  // Smallest positive subnormal half, 2^-24.
  const double thd2 = 1.0 / maxBase * pow(2, (1 - biasExp));
  double x = s ? -value : value;
  x = (x > 65504) ? 65504 : x; // saturate to the largest finite half

  if (x < thd2 / 2) {
    // Rounds to zero; the sign is deliberately dropped.
    e = 0;
    m = 0;
    s = 0;
  } else {
    // log2 is kept in float: its bit pattern is inspected to detect a NaN
    // input (log2(NaN) is NaN).
    const float log2x = log2(x);
    uint32_t log2_bits;
    std::memcpy(&log2_bits, &log2x, sizeof(log2_bits));

    if (((log2_bits >> 23) & 0xff) == 0xff) {
      // log2 produced Inf/NaN: use the half Inf/NaN exponent.
      e = maxExp - 1;
      if ((log2_bits & 0x7fffff) == 0) {
        inf_flag = true;
      } else {
        inf_flag = false;
        s = (log2_bits >> 31);
      }
    } else {
      e = biasExp + floor(log2x);
    }

    if (e > (maxExp - 1))
      printf("[double2uint16]Error: out of e range\n");
  }

  if (e <= 0) {
    // Subnormal: the mantissa counts multiples of 2^-24.
    e = 0;
    m = round(x * pow(2, (biasExp - 1)) * maxBase);
  } else if (31 == e) {
    m = inf_flag ? 0 : 1; // Inf has a zero mantissa, NaN a non-zero one
  } else {
    const double xr = x / pow(2, (e - biasExp)) - 1;
    m = round(xr * maxBase);
  }

  // '+' (not '|') lets a mantissa that rounded up to 1024 carry into the
  // exponent field.
  const uint16_t ob =
      static_cast<uint16_t>(((s & 0x1) << 15) | (((e & 0x1f) << 10) + m));
  return static_cast<int16_t>(ob);
}

} // namespace COMMON
namespace CarPlate {

/// Static description of a YOLOv5s-style detection head.
struct Yolov5sConfig {
  std::vector<int> strides; // downsampling factor of each output layer
  // Per-layer anchor pairs. The decoder in this file only uses their count.
  std::vector<std::vector<std::pair<double, double>>> anchors_table;
  int class_num;
  std::vector<std::string> class_names; // indexed by class id
};

const Yolov5sConfig carplate_config = {{8, 16, 32},
                                       {{{10, 13}, {16, 30}, {33, 23}},
                                        {{30, 61}, {62, 45}, {59, 119}},
                                        {{116, 90}, {156, 198}, {373, 326}}},
                                       1,
                                       {"carplate"}};

/// Index of the first smallest element in [first, last). Returns 0 for an
/// empty range.
template <class ForwardIterator>
inline size_t argmin(ForwardIterator first, ForwardIterator last) {
  return std::distance(first, std::min_element(first, last));
}

/// Index of the first largest element in [first, last). Returns 0 for an
/// empty range.
template <class ForwardIterator>
inline size_t argmax(ForwardIterator first, ForwardIterator last) {
  return std::distance(first, std::max_element(first, last));
}

/// Axis-aligned box given by its corner coordinates. Default-constructs to
/// all zeros (previously the members were left uninitialized).
struct Bbox {
  float xmin = 0.0f;
  float ymin = 0.0f;
  float xmax = 0.0f;
  float ymax = 0.0f;

  Bbox() = default;

  Bbox(float xmin, float ymin, float xmax, float ymax)
      : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax) {}
};

/// One detection result. class_name is a non-owning pointer (for example
/// into carplate_config.class_names) and defaults to "".
struct Detection {
  int id = 0;
  float score = 0.0f;
  Bbox bbox;
  const char *class_name = "";

  Detection() = default;

  Detection(int id, float score, Bbox bbox)
      : id(id), score(score), bbox(bbox), class_name("") {}

  Detection(int id, float score, Bbox bbox, const char *class_name)
      : id(id), score(score), bbox(bbox), class_name(class_name) {}

  /// Orders detections by score only; used to sort in descending order.
  friend bool operator>(const Detection &lhs, const Detection &rhs) {
    return lhs.score > rhs.score;
  }
};
} // namespace CarPlate

using namespace COMMON;

using namespace CarPlate;

static void yolo5Nms(std::vector<Detection> &input, float iou_threshold,
                     int top_k, std::vector<Detection> &result, bool suppress) {
  std::stable_sort(input.begin(), input.end(), std::greater<Detection>());

  std::vector<bool> skip(input.size(), false);

  std::vector<float> areas;
  areas.reserve(input.size());
  for (size_t i = 0; i < input.size(); i++) {
    float width = input[i].bbox.xmax - input[i].bbox.xmin;
    float height = input[i].bbox.ymax - input[i].bbox.ymin;
    areas.push_back(width * height);
  }

  int count = 0;
  for (size_t i = 0; /*count < top_k && */ i < skip.size(); i++) {
    if (skip[i]) {
      continue;
    }
    skip[i] = true;
    ++count;

    for (size_t j = i + 1; j < skip.size(); ++j) {
      if (skip[j]) {
        continue;
      }
      if (suppress == false) {
        if (input[i].id != input[j].id) {
          continue;
        }
      }

      float xx1 = std::max(input[i].bbox.xmin, input[j].bbox.xmin);
      float yy1 = std::max(input[i].bbox.ymin, input[j].bbox.ymin);
      float xx2 = std::min(input[i].bbox.xmax, input[j].bbox.xmax);
      float yy2 = std::min(input[i].bbox.ymax, input[j].bbox.ymax);

      if (xx2 > xx1 && yy2 > yy1) {
        float area_intersection = (xx2 - xx1) * (yy2 - yy1);
        float iou_ratio =
            area_intersection / (areas[j] + areas[i] - area_intersection);
        if (iou_ratio > iou_threshold) {
          skip[j] = true;
        }
      }
    }
    result.push_back(input[i]);
  }
}

static void tensorpostProcess(const Yolov5sConfig *cfg, void *tensor,
                              YoloPostProcessInfo_t *post_info, int layer,
                              int *filter, int filterNum,
                              std::vector<Detection> &dets) {
  // auto *data = reinterpret_cast<uint16_t *>(tensor);
  void *data = tensor;
  int class_num = cfg->class_num;
  int stride = cfg->strides[layer];
  int num_pred = cfg->class_num + 4 + 1;

  std::vector<int16_t> class_pred(cfg->class_num, 0);

  const std::vector<std::pair<double, double>> &anchors =
      cfg->anchors_table[layer];

  double h_ratio = post_info->height * 1.0 / post_info->ori_height;
  double w_ratio = post_info->width * 1.0 / post_info->ori_width;
  double resize_ratio = std::min(w_ratio, h_ratio);
  if (post_info->is_pad_resize) {
    w_ratio = resize_ratio;
    h_ratio = resize_ratio;
  }

  int grid_height, grid_width;
  grid_height = post_info->height / stride;
  grid_width = post_info->width / stride;

  int16_t box_score_threshold = float2half(post_info->score_threshold);
  for (int h = 0; h < grid_height; h++) {
    for (int w = 0; w < grid_width; w++) {
      for (size_t k = 0; k < anchors.size(); k++) {
        int16_t *cur_data = (int16_t *)data + k * num_pred;

        int16_t objness = cur_data[4];
        if (objness < box_score_threshold /*post_info->score_threshold*/) {
          continue;
        }

#if 0
int16_t id = cur_data[5];
double confidence = half2float(objness) * half2float(cur_data[6]);
#else
        for (int index = 0; index < class_num; ++index) {
          class_pred[index] = (cur_data[5 + index]);
        }

        int16_t id = argmax(class_pred.begin(), class_pred.end());
        if (filterNum > 0) {
          int i = 0;
          for (; i < filterNum; i++) {
            if (id == *(filter + i)) {
              break;
            }
          }
          if (i == filterNum) {
            continue;
          }
        }

        double confidence = half2float(objness) * half2float(class_pred[id]);
#endif

        if (confidence < post_info->score_threshold) {
          continue;
        }

#if 1
        float center_x = half2float(cur_data[0]);
        float center_y = half2float(cur_data[1]);
        float scale_x = half2float(cur_data[2]);
        float scale_y = half2float(cur_data[3]);
#else
        float center_x = cur_data[0] * post_info->width;
        float center_y = cur_data[1] * post_info->height;
        float scale_x = cur_data[2] * post_info->width;
        float scale_y = cur_data[3] * post_info->height;
#endif

        double xmin = (center_x - scale_x / 2.0);
        double ymin = (center_y - scale_y / 2.0);
        double xmax = (center_x + scale_x / 2.0);
        double ymax = (center_y + scale_y / 2.0);
        double w_padding =
            (post_info->width - w_ratio * post_info->ori_width) / 2.0;
        double h_padding =
            (post_info->height - h_ratio * post_info->ori_height) / 2.0;

        double xmin_org = (xmin - w_padding) / w_ratio;
        double xmax_org = (xmax - w_padding) / w_ratio;
        double ymin_org = (ymin - h_padding) / h_ratio;
        double ymax_org = (ymax - h_padding) / h_ratio;

        if (xmax_org <= 0 || ymax_org <= 0) {
          continue;
        }

        if (xmin_org > xmax_org || ymin_org > ymax_org) {
          continue;
        }

        xmin_org = std::max(xmin_org, 0.0);
        xmax_org = std::min(xmax_org, post_info->ori_width - 1.0);
        ymin_org = std::max(ymin_org, 0.0);
        ymax_org = std::min(ymax_org, post_info->ori_height - 1.0);

        Bbox bbox(xmin_org, ymin_org, xmax_org, ymax_org);
        dets.push_back(Detection((int)id, confidence, bbox,
                                 cfg->class_names[(int)id].c_str()));
      }
      data = (int16_t *)data + num_pred * anchors.size();
    }
  }
}

/**
 * @brief Greedy per-slot decoding of a plate-recognition output.
 *
 * @param pred        20 x 84 half-precision scores (int16 bit patterns),
 *                    row-major: row h holds the 84 class scores of slot h.
 * @param indexs      Output: indexs[h] is the index of the largest score in
 *                    row h (the first one on ties). Grown to 20 entries if
 *                    shorter.
 * @param confidence  Output: mean of the winning scores over the rows whose
 *                    winner is not class 83 (presumably the "no character"
 *                    class). Set to 0 when every row picks class 83.
 */
void num_extractor_postprocess(void *pred, std::vector<int> &indexs,
                               float *confidence) {
  constexpr int kSlots = 20;
  constexpr int kClasses = 84;
  constexpr int kBlankClass = kClasses - 1;

  if (indexs.size() < static_cast<size_t>(kSlots)) {
    indexs.resize(kSlots); // avoid writing past the end of a short vector
  }

  const int16_t *scores = static_cast<const int16_t *>(pred);
  float conf = 0.0f;
  int valid_count = 0;

  for (int h = 0; h < kSlots; h++) {
    const int16_t *row = scores + h * kClasses;

    // Compare decoded values: raw half bit patterns only order correctly for
    // non-negative numbers.
    int best_idx = 0;
    float best = half2float(row[0]);
    for (int c = 1; c < kClasses; c++) {
      const float v = half2float(row[c]);
      if (v > best) {
        best = v;
        best_idx = c;
      }
    }

    indexs[h] = best_idx;
    if (best_idx < kBlankClass) {
      valid_count++;
      conf += best;
    }
  }

  // Avoid 0/0 = NaN, which would slip past a "confidence < threshold" check.
  *confidence = (valid_count > 0) ? conf / valid_count : 0.0f;
}

/**
 * @brief Run the car-plate detector post-processing and fill boxesInfo.
 *
 * Steps:
 * 1. Every output layer of carplate_config is decoded in turn. The layers
 *    are read back to back from output_tensor.
 * 2. Class-aware NMS runs on the candidates (nms_top_k is forwarded, but
 *    yolo5Nms does not currently enforce it).
 * 3. Every surviving box scoring at least 0.6 is written to
 *    boxesInfo->boxes, and boxesInfo->boxesNum is set.
 *
 * @param post_info  Geometry, thresholds and tensor/box buffers.
 * @param filter     Optional list of class ids to keep.
 * @param filterNum  Number of ids in @p filter; <= 0 keeps every class.
 * @return 0 on success. In LYNXI_PLUGIN builds, -1 if output_tensor cannot
 *         be mapped and -2 if boxesInfo cannot be mapped.
 */
int carPlateDetectPostProcess(YoloPostProcessInfo_t *post_info, int *filter,
                              int filterNum) {
  // Boxes scoring below this after NMS are not reported.
  constexpr double kMinOutputScore = 0.6;

  int16_t *output_tensor = nullptr;
  lynBoxesInfo *boxesInfo = nullptr;
#ifdef LYNXI_PLUGIN
  output_tensor = (int16_t *)lynPluginGetVirtAddr(post_info->output_tensor);
  if (output_tensor == nullptr) {
    LOG_PLUGIN_E("get virtual addr of output_tensor error\n");
    return -1;
  }
  boxesInfo = (lynBoxesInfo *)lynPluginGetVirtAddr(post_info->boxesInfo);
  if (boxesInfo == nullptr) {
    LOG_PLUGIN_E("get virtual addr of boxesInfo error\n");
    return -2;
  }
#else
  output_tensor = (int16_t *)post_info->output_tensor;
  boxesInfo = post_info->boxesInfo;
#endif

  const Yolov5sConfig *cfg = &carplate_config;
  std::vector<Detection> dets;

  // Layer L occupies anchors(L) * grid_w(L) * grid_h(L) * (class_num + 5)
  // halves, so the offsets are derived from the config instead of being
  // hard-coded.
  size_t offset = 0;
  for (size_t layer = 0; layer < cfg->strides.size(); ++layer) {
    tensorpostProcess(cfg, output_tensor + offset, post_info,
                      static_cast<int>(layer), filter, filterNum, dets);

    const int stride = cfg->strides[layer];
    const int grid_width = post_info->width / stride;
    const int grid_height = post_info->height / stride;
    const size_t anchor_num = cfg->anchors_table[layer].size();
    offset += anchor_num * static_cast<size_t>(grid_width) *
              static_cast<size_t>(grid_height) *
              static_cast<size_t>(cfg->class_num + 5);
  }

  std::vector<Detection> det_results;
  yolo5Nms(dets, post_info->nms_threshold, post_info->nms_top_k, det_results,
           false);

  // NOTE(review): the capacity of boxesInfo->boxes is not visible here and
  // boxesNum is not bounded against it — confirm against lynBoxesInfo.
  boxesInfo->boxesNum = 0;
  for (const Detection &det : det_results) {
    if (det.score < kMinOutputScore) {
      continue;
    }
    auto &box = boxesInfo->boxes[boxesInfo->boxesNum];
    box.xmax = det.bbox.xmax;
    box.xmin = det.bbox.xmin;
    box.ymax = det.bbox.ymax;
    box.ymin = det.bbox.ymin;
    box.score = det.score;
    boxesInfo->boxesNum += 1;
  }
  return 0;
}

// Recognition alphabet: character k corresponds to recognizer class index k.
//   0-30  : 31 region characters (京 ... 新)
//   31-40 : digits 0-9
//   41-64 : Latin letters A-Z without I and O
//   65-82 : 18 special characters (港 ... 空)
// Callers in this file skip indices >= PLATE_DATA.length() (83).
static const std::wstring PLATE_DATA =
    L"京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新"
    L"0123456789"
    L"ABCDEFGHJKLMNPQRSTUVWXYZ"
    L"港学使警澳挂军北南广沈兰成济海民航空";

/**
 * @brief Encode a wide string as UTF-8.
 *
 * Each wchar_t is treated as one Unicode code point. Values that are not
 * valid Unicode scalar values (surrogates U+D800-U+DFFF, or anything above
 * U+10FFFF) are emitted as U+FFFD instead of throwing.
 *
 * This replaces std::wstring_convert/std::codecvt_utf8, which are deprecated
 * since C++17 and throw std::range_error on unencodable input.
 *
 * @param wide_string Input text.
 * @return UTF-8 encoded bytes.
 */
std::string narrow(const std::wstring &wide_string) {
  std::string out;
  out.reserve(wide_string.size() * 3);

  for (wchar_t wc : wide_string) {
    char32_t cp = static_cast<char32_t>(wc);
    if (cp > 0x10FFFFu || (cp >= 0xD800u && cp <= 0xDFFFu)) {
      cp = 0xFFFDu; // replacement character
    }

    if (cp < 0x80u) {
      out.push_back(static_cast<char>(cp));
    } else if (cp < 0x800u) {
      out.push_back(static_cast<char>(0xC0u | (cp >> 6)));
      out.push_back(static_cast<char>(0x80u | (cp & 0x3Fu)));
    } else if (cp < 0x10000u) {
      out.push_back(static_cast<char>(0xE0u | (cp >> 12)));
      out.push_back(static_cast<char>(0x80u | ((cp >> 6) & 0x3Fu)));
      out.push_back(static_cast<char>(0x80u | (cp & 0x3Fu)));
    } else {
      out.push_back(static_cast<char>(0xF0u | (cp >> 18)));
      out.push_back(static_cast<char>(0x80u | ((cp >> 12) & 0x3Fu)));
      out.push_back(static_cast<char>(0x80u | ((cp >> 6) & 0x3Fu)));
      out.push_back(static_cast<char>(0x80u | (cp & 0x3Fu)));
    }
  }
  return out;
}

/**
 * @brief Build the UTF-8 plate string from 20 recognizer class indices.
 *
 * Byte vIndexs[k] selects character PLATE_DATA[k-th index]. Indices outside
 * [0, PLATE_DATA.length()) carry no character and are skipped; this includes
 * the class equal to the alphabet length.
 *
 * @param vIndexs At least 20 class indices, one per byte.
 * @return The decoded plate text, UTF-8 encoded.
 */
std::string get_car_number(char *vIndexs) {
  constexpr int kSlots = 20;
  const int alphabet_size = static_cast<int>(PLATE_DATA.length());

  std::wstring res;
  res.reserve(kSlots);
  for (int k = 0; k < kSlots; ++k) {
    // Read the byte as unsigned: with a signed char, a value >= 128 became
    // negative, passed the upper-bound check, and indexed out of bounds.
    const int idx = static_cast<unsigned char>(vIndexs[k]);
    if (idx >= alphabet_size) {
      continue;
    }
    res += PLATE_DATA[idx];
  }

  return narrow(res);
}

/**
 * @brief Turn plate-recognition outputs into text labels on boxesInfo.
 *
 * For each of post_info->targetNum targets, output_tensor[i] holds the
 * address of that target's recognition output. The slots are decoded with
 * num_extractor_postprocess.
 * - If the mean confidence is below 0.85 (or not a number), or the address
 *   is null, boxesInfo->boxes[i].label is cleared.
 * - Otherwise the label receives the decoded plate string (UTF-8).
 *
 * @param post_info Recognition inputs and the box/label buffer.
 * @return 0 on success. In LYNXI_PLUGIN builds, -1 if the tensor-address
 *         table cannot be mapped and -2 if boxesInfo cannot be mapped.
 */
int carPlateRecogPostProcess(PlateRecogPostProcessInfo_t *post_info) {
  constexpr int kSlots = 20; // character slots per plate
  constexpr double kMinConfidence = 0.85;

  uintptr_t *output_tensor = nullptr;
  lynBoxesInfo *boxesInfo = nullptr;
#ifdef LYNXI_PLUGIN
  output_tensor = (uintptr_t *)lynPluginGetVirtAddr(post_info->output_tensor);
  if (output_tensor == nullptr) {
    LOG_PLUGIN_E("get virtual addr of output_tensor error\n");
    return -1;
  }
  boxesInfo = (lynBoxesInfo *)lynPluginGetVirtAddr(post_info->boxesInfo);
  if (boxesInfo == nullptr) {
    LOG_PLUGIN_E("get virtual addr of boxesInfo error\n");
    return -2;
  }
#else
  output_tensor = (uintptr_t *)post_info->output_tensor;
  boxesInfo = post_info->boxesInfo;
#endif
  for (int i = 0; i < post_info->targetNum; ++i) {
#ifdef LYNXI_PLUGIN
    auto addr = lynPluginGetVirtAddr((void *)output_tensor[i]);
#else
    auto addr = (void *)output_tensor[i];
#endif
    if (addr == nullptr) {
      // No recognition output for this target: clear the label, as the
      // low-confidence path does, instead of leaving a stale one.
      strcpy(boxesInfo->boxes[i].label, "");
      continue;
    }

    std::vector<int> indexes(kSlots, 100);
    float confidence = 0.0f;
    num_extractor_postprocess(addr, indexes, &confidence);
    // Written as !(>=) so that a NaN confidence is rejected too.
    if (!(confidence >= kMinConfidence)) {
      strcpy(boxesInfo->boxes[i].label, "");
      continue;
    }

    char text[kSlots];
    for (int j = 0; j < kSlots; ++j) {
      text[j] = static_cast<char>(indexes[j]);
    }
    // NOTE(review): the label buffer size is not visible here; a 20-slot
    // plate can need up to 60 UTF-8 bytes plus the terminator — confirm
    // the label capacity in lynBoxesInfo.
    strcpy(boxesInfo->boxes[i].label, get_car_number(text).c_str());
  }
  return 0;
}
